{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[download this notebook here](https://github.com/HumanCompatibleAI/imitation/blob/master/docs/tutorials/5_train_preference_comparisons.ipynb)\n",
    "# Learning a Reward Function using Preference Comparisons\n",
    "\n",
    "The preference comparisons algorithm learns a reward function by comparing trajectory segments to each other."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The preference comparisons algorithm is assembled from several components, which we need to set up first:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "from imitation.algorithms import preference_comparisons\n",
    "from imitation.rewards.reward_nets import BasicRewardNet\n",
    "from imitation.util.networks import RunningNorm\n",
    "from imitation.util.util import make_vec_env\n",
    "from imitation.policies.base import FeedForward32Policy, NormalizeFeaturesExtractor\n",
    "import gymnasium as gym\n",
    "from stable_baselines3 import PPO\n",
    "import numpy as np\n",
    "\n",
    "# One seeded generator is passed to every component below that accepts `rng`.\n",
    "rng = np.random.default_rng(seed=0)\n",
    "\n",
    "venv = make_vec_env(\"Pendulum-v1\", rng=rng)\n",
    "\n",
    "# Reward network built from the environment's spaces; RunningNorm is supplied\n",
    "# as its input-normalization layer.\n",
    "reward_net = BasicRewardNet(\n",
    "    observation_space=venv.observation_space,\n",
    "    action_space=venv.action_space,\n",
    "    normalize_input_layer=RunningNorm,\n",
    ")\n",
    "\n",
    "# Fragmenter, preference gatherer, preference model and reward trainer.\n",
    "fragmenter = preference_comparisons.RandomFragmenter(rng=rng, warning_threshold=0)\n",
    "gatherer = preference_comparisons.SyntheticGatherer(rng=rng)\n",
    "preference_model = preference_comparisons.PreferenceModel(reward_net)\n",
    "reward_trainer = preference_comparisons.BasicRewardTrainer(\n",
    "    rng=rng,\n",
    "    epochs=3,\n",
    "    loss=preference_comparisons.CrossEntropyRewardLoss(),\n",
    "    preference_model=preference_model,\n",
    ")\n",
    "\n",
    "\n",
    "# Several hyperparameters used in this example have been approximately\n",
    "# fine-tuned to reach a reasonable level of performance: the reward trainer's\n",
    "# epochs; PPO's clip_range, ent_coef, gae_lambda, n_epochs, gamma (the discount\n",
    "# factor) and learning_rate; and exploration_frac, num_iterations,\n",
    "# initial_comparison_frac, initial_epoch_multiplier and query_schedule.\n",
    "# (use_sde and sde_sample_freq are not passed here, so PPO's defaults apply.)\n",
    "agent = PPO(\n",
    "    env=venv,\n",
    "    policy=FeedForward32Policy,\n",
    "    policy_kwargs={\n",
    "        \"features_extractor_class\": NormalizeFeaturesExtractor,\n",
    "        \"features_extractor_kwargs\": {\"normalize_class\": RunningNorm},\n",
    "    },\n",
    "    seed=0,\n",
    "    # 2048 integer-divided by the number of environments in venv.\n",
    "    n_steps=2048 // venv.num_envs,\n",
    "    batch_size=64,\n",
    "    n_epochs=10,\n",
    "    learning_rate=2e-3,\n",
    "    gamma=0.97,\n",
    "    gae_lambda=0.95,\n",
    "    clip_range=0.1,\n",
    "    ent_coef=0.01,\n",
    ")\n",
    "\n",
    "# AgentTrainer receives the PPO agent, the reward network (as its reward\n",
    "# function) and the vectorized environment.\n",
    "trajectory_generator = preference_comparisons.AgentTrainer(\n",
    "    venv=venv,\n",
    "    algorithm=agent,\n",
    "    reward_fn=reward_net,\n",
    "    rng=rng,\n",
    "    exploration_frac=0.05,\n",
    ")\n",
    "\n",
    "# Combine the trajectory generator, reward network and the helper components\n",
    "# created above into the preference comparisons algorithm.\n",
    "pref_comparisons = preference_comparisons.PreferenceComparisons(\n",
    "    trajectory_generator,\n",
    "    reward_net,\n",
    "    reward_trainer=reward_trainer,\n",
    "    preference_gatherer=gatherer,\n",
    "    fragmenter=fragmenter,\n",
    "    num_iterations=5,  # Set to 60 for better performance\n",
    "    query_schedule=\"hyperbolic\",\n",
    "    initial_epoch_multiplier=4,\n",
    "    initial_comparison_frac=0.1,\n",
    "    transition_oversampling=1,\n",
    "    fragment_length=100,\n",
    "    allow_variable_horizon=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we can start training the reward model. Note that we need to specify the total number of timesteps the agent should be trained for and how many fragment comparisons should be made."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training budget: total agent timesteps and total fragment comparisons.\n",
    "pref_comparisons.train(total_comparisons=200, total_timesteps=5_000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After we have trained the reward network using the preference comparisons algorithm, we can wrap our environment with the learned reward."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from imitation.rewards.reward_wrapper import RewardVecEnvWrapper\n",
    "\n",
    "# The wrapper is given reward_net.predict_processed as its reward function.\n",
    "learned_reward_venv = RewardVecEnvWrapper(\n",
    "    venv=venv,\n",
    "    reward_fn=reward_net.predict_processed,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we train an agent that sees only the shaped, learned reward."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "learner = PPO(\n",
    "    env=learned_reward_venv,\n",
    "    policy=FeedForward32Policy,\n",
    "    policy_kwargs={\n",
    "        \"features_extractor_class\": NormalizeFeaturesExtractor,\n",
    "        \"features_extractor_kwargs\": {\"normalize_class\": RunningNorm},\n",
    "    },\n",
    "    seed=0,\n",
    "    # 2048 integer-divided by the number of environments in the wrapped venv.\n",
    "    n_steps=2048 // learned_reward_venv.num_envs,\n",
    "    batch_size=64,\n",
    "    n_epochs=10,\n",
    "    learning_rate=2e-3,\n",
    "    gamma=0.97,\n",
    "    gae_lambda=0.95,\n",
    "    clip_range=0.1,\n",
    "    ent_coef=0.01,\n",
    ")\n",
    "learner.learn(total_timesteps=1_000)  # Note: set to 100_000 to train a proficient expert"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then we can evaluate the trained agent using the environment's original reward."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from stable_baselines3.common.evaluation import evaluate_policy\n",
    "\n",
    "n_eval_episodes = 10\n",
    "# Evaluate on the unwrapped venv rather than learned_reward_venv.\n",
    "reward_mean, reward_std = evaluate_policy(\n",
    "    model=learner.policy,\n",
    "    env=venv,\n",
    "    n_eval_episodes=n_eval_episodes,\n",
    ")\n",
    "# Standard error: the standard deviation scaled by sqrt(number of episodes).\n",
    "reward_stderr = reward_std / np.sqrt(n_eval_episodes)\n",
    "print(f\"Reward: {reward_mean:.0f} +/- {reward_stderr:.0f}\")"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "439158cd89905785fcc749928062ade7bfccc3f087fab145e5671f895c635937"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
